{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "35d1db70-b03b-4225-b915-9bd7b7d36536",
   "metadata": {},
   "source": [
    "Reference: [scikit-learn user guide, SVM regression](https://scikit-learn.org/1.1/modules/svm.html#regression)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "611b9e55-70e6-4c32-9c99-caa8f1aa3660",
   "metadata": {},
   "source": [
    "## Training and Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6afd5aa9-9da0-484b-b50a-6e24637b9b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((442, 10), (442,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import load_diabetes\n",
    "\n",
    "# Load the diabetes dataset as a (features, target) pair of arrays.\n",
    "X, y = load_diabetes(return_X_y=True)\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3be6782a-f115-4182-ac26-ef9c0d4f2e11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 0.03807591,  0.05068012,  0.06169621,  0.02187239, -0.0442235 ,\n",
       "        -0.03482076, -0.04340085, -0.00259226,  0.01990749, -0.01764613]),\n",
       " 151.0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First sample: feature vector and its target value.\n",
    "X[0], y[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "324d5e22-cd46-4d24-b246-dd1c7c7853f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,\n",
       "        -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405]),\n",
       " 75.0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Second sample: feature vector and its target value.\n",
    "X[1], y[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b40a740-e8d1-4dd8-91a3-13c8dd6fcb68",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
      "  warnings.warn(\"Setuptools is replacing distutils.\")\n",
      "2023/03/29 13:13:56 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_svr' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:13:56 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_svr, version 2\n",
      "Created version '2' of model 'da_svr'.\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/liga/mlflow/logger.py:137: UserWarning: value of rikai.output.schema is None or empty and will not be populated to MLflow\n",
      "  warnings.warn(\n",
      "2023/03/29 13:13:58 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_nusvr' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:13:58 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_nusvr, version 2\n",
      "Created version '2' of model 'da_nusvr'.\n",
      "2023/03/29 13:13:59 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_linear_svr' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:14:00 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_linear_svr, version 2\n",
      "Created version '2' of model 'da_linear_svr'.\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "import mlflow\n",
    "from liga.sklearn.mlflow import log_model\n",
    "from sklearn import svm\n",
    "\n",
    "\n",
    "mlflow_tracking_uri = \"sqlite:///mlruns.db\"\n",
    "mlflow.set_tracking_uri(mlflow_tracking_uri)\n",
    "\n",
    "# Registered model names carry the current OS user name as a prefix;\n",
    "# look the user up once instead of once per name.\n",
    "user = getpass.getuser()\n",
    "svr_name = f\"{user}_svr\"\n",
    "nusvr_name = f\"{user}_nusvr\"\n",
    "l_svr_name = f\"{user}_linear_svr\"\n",
    "\n",
    "# Fit three SVM regressors on the diabetes data loaded above (X, y).\n",
    "model_svr = svm.SVR(epsilon=0.3).fit(X, y)\n",
    "model_nusvr = svm.NuSVR().fit(X, y)\n",
    "# LinearSVR uses a random number generator while fitting; pin the seed\n",
    "# so that Restart & Run All reproduces the same model.\n",
    "model_l_svr = svm.LinearSVR(random_state=0).fit(X, y)\n",
    "\n",
    "# Log and register each model in its OWN MLflow run.\n",
    "# When all three were logged inside one shared run, the three registered\n",
    "# models returned identical predictions for the same input.\n",
    "# NOTE(review): presumably log_model writes every model to the same artifact\n",
    "# location within a run, so later calls replace earlier ones -- confirm\n",
    "# against liga's log_model. One run per model avoids any such collision.\n",
    "for fitted_model, registered_name in (\n",
    "    (model_svr, svr_name),\n",
    "    (model_nusvr, nusvr_name),\n",
    "    (model_l_svr, l_svr_name),\n",
    "):\n",
    "    with mlflow.start_run():\n",
    "        log_model(fitted_model, registered_model_name=registered_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adad1843-8e9b-4dc5-89ac-892ee670a740",
   "metadata": {},
   "source": [
    "## Apply the models to a large-scale dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d73eb3f5-c3f8-47d1-9e3d-04da4ca39e11",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-29 13:14:00,112 INFO Rikai (__init__.py:121): setting spark.sql.extensions to net.xmacs.liga.spark.RikaiSparkSessionExtensions\n",
      "2023-03-29 13:14:00,113 INFO Rikai (__init__.py:121): setting spark.driver.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:00,114 INFO Rikai (__init__.py:121): setting spark.executor.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:00,115 INFO Rikai (__init__.py:121): setting spark.jars to https://github.com/komprenilo/liga/releases/download/v0.3.0/liga-spark321-assembly_2.12-0.3.0.jar\n",
      "23/03/29 13:14:01 WARN Utils: Your hostname, tubi resolves to a loopback address: 127.0.1.1; using 192.168.31.32 instead (on interface wlp0s20f3)\n",
      "23/03/29 13:14:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/03/29 13:14:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "23/03/29 13:14:06 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
      "23/03/29 13:14:06 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+------+-----------------------+-------+\n",
      "|name |plugin|uri                    |options|\n",
      "+-----+------+-----------------------+-------+\n",
      "|svr  |      |mlflow:///da_svr       |       |\n",
      "|nusvr|      |mlflow:///da_nusvr     |       |\n",
      "|l_svr|      |mlflow:///da_linear_svr|       |\n",
      "+-----+------+-----------------------+-------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from example import spark\n",
    "from liga.mlflow import CONF_MLFLOW_TRACKING_URI\n",
    "\n",
    "# Session settings: disable Arrow for pandas conversion and pass the\n",
    "# MLflow tracking URI used during training on to Spark.\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")\n",
    "spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)\n",
    "\n",
    "# One statement template, filled in once per model.\n",
    "create_model_sql = \"\"\"\n",
    "CREATE OR REPLACE MODEL {alias} LOCATION 'mlflow:///{registered_name}';\n",
    "\"\"\"\n",
    "\n",
    "# SQL model alias -> registered MLflow model name.\n",
    "sql_models = {\n",
    "    \"svr\": svr_name,\n",
    "    \"nusvr\": nusvr_name,\n",
    "    \"l_svr\": l_svr_name,\n",
    "}\n",
    "for alias, registered_name in sql_models.items():\n",
    "    spark.sql(create_model_sql.format(alias=alias, registered_name=registered_name))\n",
    "\n",
    "spark.sql(\"show models\").show(10, vertical=False, truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "658686cb-a5db-4dcf-ac34-4323d00b36bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- svr: float (nullable = true)\n",
      " |-- nusvr: float (nullable = true)\n",
      " |-- l_svr: float (nullable = true)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>svr</th>\n",
       "      <th>nusvr</th>\n",
       "      <th>l_svr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>110.177177</td>\n",
       "      <td>110.177177</td>\n",
       "      <td>110.177177</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          svr       nusvr       l_svr\n",
       "0  110.177177  110.177177  110.177177"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feature vector of the first sample (X[0]) as a Spark SQL array literal.\n",
    "first_sample_sql = \"\"\"array(0.03807591,  0.05068012,  0.06169621,  0.02187239, -0.0442235 ,\n",
    "        -0.03482076, -0.04340085, -0.00259226,  0.01990749, -0.01764613)\"\"\"\n",
    "\n",
    "# One ML_PREDICT column per SQL model, all scoring the same sample.\n",
    "prediction_columns = \",\\n\".join(\n",
    "    f\"  ML_PREDICT({alias}, {first_sample_sql}) as {alias}\"\n",
    "    for alias in (\"svr\", \"nusvr\", \"l_svr\")\n",
    ")\n",
    "\n",
    "result = spark.sql(f\"\"\"\n",
    "select\n",
    "{prediction_columns}\n",
    "        \n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "result.printSchema()\n",
    "result.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "25d20ae5-6b4b-448b-8f21-0a26dc72330f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>svr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>106.438087</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          svr\n",
       "0  106.438087"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the second sample (X[1]) with the `svr` model.\n",
    "svr_query = \"\"\"\n",
    "select\n",
    "  ML_PREDICT(svr, array(-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,\n",
    "        -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405)) as svr\n",
    "\"\"\"\n",
    "spark.sql(svr_query).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0df9f3f4-e79e-4b32-bb85-e94ad78500ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>nusvr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>106.438087</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        nusvr\n",
       "0  106.438087"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the second sample (X[1]) with the `nusvr` model.\n",
    "nusvr_query = \"\"\"\n",
    "select  ML_PREDICT(nusvr, array(-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,\n",
    "        -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405)) as nusvr\n",
    "\"\"\"\n",
    "spark.sql(nusvr_query).toPandas()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "77987f10-9215-492b-acd0-3243ce6ac294",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>l_svr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>106.438087</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        l_svr\n",
       "0  106.438087"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the second sample (X[1]) with the `l_svr` model.\n",
    "l_svr_query = \"\"\"\n",
    "select ML_PREDICT(l_svr, array(-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,\n",
    "        -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405)) as l_svr\n",
    "\"\"\"\n",
    "spark.sql(l_svr_query).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eacbe16a-c846-4ac3-bff2-673c846cd6fa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
